{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "94577b80-1b93-4324-a9bd-a682364b9703",
   "metadata": {},
   "source": [
    "This is an example of a train and test pipeline for the PaiNN model from the schnetpack library.  \n",
    "The same task can be performed with the pre-defined config by running the following command from the repository root:\n",
    "```bash\n",
    "python run.py --config-name painn.yaml\n",
    "```\n",
    "For a detailed description, please refer to the [README](../nablaDFT/README.md).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6f3864e",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Train/test cycles example using PaiNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c4edcbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Based on https://github.com/atomistic-machine-learning/schnetpack/blob/master/examples/tutorials/tutorial_02_qm9.ipynb\n",
    "import os\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "import schnetpack as spk\n",
    "import schnetpack.representation as rep\n",
    "import schnetpack.task as task\n",
    "import schnetpack.transform as trn\n",
    "import torch\n",
    "import torchmetrics\n",
    "from nablaDFT import model_registry\n",
    "from nablaDFT.ase_model import AtomisticTaskFixed\n",
    "from nablaDFT.dataset import ASENablaDFT, dataset_registry\n",
    "from nablaDFT.dataset.split import TestSplit\n",
    "from nablaDFT.utils import seed_everything\n",
    "from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a01bf1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tunable settings for this example; the cells below read these names.\n",
    "dataset_name = \"dataset_train_tiny\"  # Name of the training dataset\n",
    "datapath = \"database\"  # Path to the selected dataset\n",
    "logspath = \"logs\"  # Directory for training logs and checkpoints (used as workpath below)\n",
    "nepochs = 200  # Number of epochs to train for\n",
    "seed = 1799  # Random seed number for reproducibility\n",
    "batch_size = 32  # Size of each batch for training\n",
    "train_ratio = 0.9  # Part of dataset used for training\n",
    "val_ratio = 0.1  # Part of dataset used for validation\n",
    "n_interactions = 6  # Number of interactions to consider between atoms\n",
    "n_atom_basis = 128  # Number of basis functions for atoms in the representation\n",
    "n_rbf = 100  # Number of radial basis functions in the representation\n",
    "cutoff = 5.0  # Cutoff distance (in Bohr) for computing interactions\n",
    "devices = 1  # Number of devices passed to pl.Trainer (the training cell sets accelerator to gpu)\n",
    "# The same transform list is passed as both train_transforms and val_transforms.\n",
    "transforms = [\n",
    "    trn.ASENeighborList(cutoff=cutoff),\n",
    "    trn.RemoveOffsets(\"energy\", remove_mean=True, remove_atomrefs=False),\n",
    "    trn.CastTo32(),\n",
    "]  # data transforms used for training and validation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0683997-02ea-4e82-a83a-91b245b1c1e6",
   "metadata": {},
   "source": [
    "This example uses the *tiny* train dataset split.  \n",
    "All available dataset splits can be listed with:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "513620bb-db0f-4410-9af0-1ed70b3e1249",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the dataset splits that dataset_registry lists under the \"energy\" key.\n",
    "dataset_registry.list_datasets(\"energy\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6498b248",
   "metadata": {},
   "source": [
    "## Downloading dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a822721",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the configured random seed before the datamodule is built.\n",
    "seed_everything(seed)\n",
    "\n",
    "# Training logs and checkpoints are written under this directory.\n",
    "workpath = logspath\n",
    "os.makedirs(workpath, exist_ok=True)\n",
    "\n",
    "# Use the dataset name and root chosen in the configuration cell rather than\n",
    "# repeating the literals, so changing the configuration actually takes effect.\n",
    "datamodule = ASENablaDFT(\n",
    "    dataset_name=dataset_name,\n",
    "    split=\"train\",\n",
    "    root=datapath,\n",
    "    batch_size=batch_size,\n",
    "    train_ratio=train_ratio,\n",
    "    val_ratio=val_ratio,\n",
    "    split_file=None,\n",
    "    num_workers=4,\n",
    "    train_transforms=transforms,\n",
    "    val_transforms=transforms,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b208fcef",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Initializing training procedure and starting training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "028a4f5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Model: PaiNN representation with an energy head and a forces head ---\n",
    "pairwise_distance = spk.atomistic.PairwiseDistances()\n",
    "radial_basis = spk.nn.radial.GaussianRBF(n_rbf=n_rbf, cutoff=cutoff)\n",
    "cutoff_fn = spk.nn.cutoff.CosineCutoff(cutoff)\n",
    "representation = rep.PaiNN(\n",
    "    n_interactions=n_interactions, n_atom_basis=n_atom_basis, radial_basis=radial_basis, cutoff_fn=cutoff_fn\n",
    ")\n",
    "pred_energy = spk.atomistic.Atomwise(n_in=representation.n_atom_basis, output_key=\"energy\")\n",
    "pred_forces = spk.atomistic.Forces()\n",
    "# The data transforms remove the mean energy (RemoveOffsets); add it back on the model output.\n",
    "postprocessors = [trn.AddOffsets(\"energy\", add_mean=True)]\n",
    "nnpot = spk.model.NeuralNetworkPotential(\n",
    "    representation=representation,\n",
    "    input_modules=[pairwise_distance],\n",
    "    output_modules=[pred_energy, pred_forces],\n",
    "    postprocessors=postprocessors,\n",
    ")\n",
    "\n",
    "# --- Training targets: MSE loss and MAE metric for both energy and forces, equal weights ---\n",
    "output_energy = spk.task.ModelOutput(\n",
    "    name=\"energy\", loss_fn=torch.nn.MSELoss(), loss_weight=1, metrics={\"MAE\": torchmetrics.MeanAbsoluteError()}\n",
    ")\n",
    "output_forces = spk.task.ModelOutput(\n",
    "    name=\"forces\", loss_fn=torch.nn.MSELoss(), loss_weight=1, metrics={\"MAE\": torchmetrics.MeanAbsoluteError()}\n",
    ")\n",
    "\n",
    "scheduler_args = {\"factor\": 0.8, \"patience\": 10, \"min_lr\": 1e-06}\n",
    "\n",
    "# Named painn_task (not task) so the schnetpack.task module imported as `task` is not overwritten.\n",
    "painn_task = AtomisticTaskFixed(\n",
    "    model_name=\"PaiNN\",\n",
    "    model=nnpot,\n",
    "    outputs=[output_energy, output_forces],\n",
    "    optimizer_cls=torch.optim.AdamW,\n",
    "    optimizer_args={\"lr\": 1e-4},\n",
    "    scheduler_cls=ReduceLROnPlateau,\n",
    "    scheduler_args=scheduler_args,\n",
    "    scheduler_monitor=\"val_loss\",\n",
    ")\n",
    "\n",
    "# --- Trainer: TensorBoard logging, LR monitoring, keep only the best checkpoint by val_loss ---\n",
    "logger = pl.loggers.TensorBoardLogger(save_dir=workpath)\n",
    "lr_monitor = LearningRateMonitor(logging_interval=\"step\")\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    save_top_k=1,\n",
    "    monitor=\"val_loss\",\n",
    "    mode=\"min\",\n",
    "    dirpath=f\"{workpath}/checkpoints\",\n",
    "    filename=\"Painn-{epoch:03d}_{val_loss:4f}\",\n",
    ")\n",
    "callbacks = [lr_monitor, checkpoint_callback]\n",
    "\n",
    "trainer = pl.Trainer(\n",
    "    accelerator=\"gpu\",\n",
    "    devices=devices,\n",
    "    callbacks=callbacks,\n",
    "    logger=logger,\n",
    "    default_root_dir=workpath,\n",
    "    max_epochs=nepochs,\n",
    ")\n",
    "\n",
    "trainer.fit(painn_task, datamodule=datamodule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3e51125-0334-4db5-9774-14b2c84bc2f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the best checkpoint kept by the ModelCheckpoint callback (lowest val_loss, save_top_k=1).\n",
    "# NOTE(review): the test cell below loads a pretrained model from model_registry and does not read this path.\n",
    "ckpt_path = trainer.checkpoint_callback.best_model_path"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "054f30f4",
   "metadata": {},
   "source": [
    "## Initializing the testing procedure and computing the metric's result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34793757-42cb-4957-b0aa-9ce63693a71d",
   "metadata": {},
   "source": [
    "We will test a pretrained model obtained from the *model_registry* object.  \n",
    "You can list all available pretrained models with:  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13461e19-4fa4-4426-b0d2-4ac0a948acc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the pretrained models that model_registry can provide.\n",
    "model_registry.list_models()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cfe5514",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "cutoff = 5.0\n",
    "gpu = 0  # -1 runs the test on CPU; any other value runs it on one GPU\n",
    "\n",
    "# The accelerator names the hardware and the device count says how many to use.\n",
    "# They are chosen together so that the CPU case does not request a gpu accelerator.\n",
    "accelerator = \"cpu\" if gpu == -1 else \"gpu\"\n",
    "device = 1\n",
    "\n",
    "# The whole dataset is used for testing: test_ratio=1.0 with the TestSplit strategy.\n",
    "datamodule_test = ASENablaDFT(\n",
    "    dataset_name=\"dataset_test_conformations_tiny\",\n",
    "    split=\"test\",\n",
    "    root=\"database_test\",\n",
    "    batch_size=batch_size,\n",
    "    train_ratio=0.0,\n",
    "    val_ratio=0.0,\n",
    "    test_ratio=1.0,\n",
    "    num_workers=4,\n",
    "    split_file=None,\n",
    "    test_transforms=[trn.ASENeighborList(cutoff=cutoff), trn.CastTo32()],\n",
    "    splitting=TestSplit(),\n",
    ")\n",
    "\n",
    "# A separate name keeps the training Trainer and its checkpoint callback available.\n",
    "# NOTE(review): inference_mode=False is presumably needed because forces come from energy gradients - confirm.\n",
    "test_trainer = pl.Trainer(accelerator=accelerator, devices=device, inference_mode=False, logger=False)\n",
    "model = model_registry.get_pretrained_model(\"lightning\", \"PaiNN_train_tiny\")\n",
    "\n",
    "test_trainer.test(model=model, datamodule=datamodule_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
